{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert and Run Inference on a PyTorch Model with Custom Ops\n",
    "\n",
    "This notebook demonstrates how to use onnxruntime-extensions to run a PyTorch model that contains operators that are not part of the ONNX standard."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model definition and export to ONNX\n",
    "Suppose a model cannot be converted because the standard ONNX opset has no matrix inverse operator. The following model is such a case:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "class CustomInverse(torch.nn.Module):\n",
    "    def forward(self, x):\n",
    "        return torch.inverse(x) + x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To export this model to ONNX format, we need to register a custom op symbolic function with the PyTorch ONNX exporter (`torch.onnx`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.onnx import register_custom_op_symbolic\n",
    "\n",
    "\n",
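    "def my_inverse(g, self):\n",
    "    # Emit a custom Inverse node in the ai.onnx.contrib domain in place of torch.inverse.\n",
    "    return g.op(\"ai.onnx.contrib::Inverse\", self)\n",
    "\n",
    "# Register the symbolic function for the PyTorch 'inverse' operator.\n",
    "register_custom_op_symbolic('::inverse', my_inverse, 1)"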
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then invoke the exporter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import io\n",
    "import onnx\n",
    "\n",
    "x0 = torch.randn(3, 3)\n",
    "# Export model to ONNX\n",
    "f = io.BytesIO()\n",
    "t_model = CustomInverse()\n",
    "torch.onnx.export(t_model, (x0, ), f, opset_version=12)\n",
    "onnx_model = onnx.load(io.BytesIO(f.getvalue()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now have an ONNX model in memory. It can be saved to disk with `onnx.save_model(onnx_model, <file_path>)`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "The converted model cannot run directly in onnxruntime because it contains a custom operator, but it runs easily with onnxruntime_extensions.\n",
    "\n",
    "First, define a PyOp function to interpret the custom op node in the ONNX model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "from onnxruntime_extensions import onnx_op, PyOp\n",
    "\n",
    "\n",
    "@onnx_op(op_type=\"Inverse\")\n",
    "def inverse(x):\n",
    "    # User-supplied implementation of the custom op: invert the input matrix.\n",
    "    return numpy.linalg.inv(x)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* **ONNX Inference**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-3.081008    0.20269153  0.42009977]\n",
      " [-3.3962293   2.5986686   2.4447646 ]\n",
      " [ 0.7805753  -0.20394287 -2.7528977 ]]\n"
     ]
    }
   ],
   "source": [
    "from onnxruntime_extensions import PyOrtFunction\n",
    "onnx_fn = PyOrtFunction.from_model(onnx_model)\n",
    "y = onnx_fn(x0.numpy())\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* **Compare the result with PyTorch**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
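    "t_y = t_model(x0)\n",
    "# Convert the PyTorch output to a NumPy array before comparing with the ONNX result.\n",
    "numpy.testing.assert_almost_equal(t_y.detach().numpy(), y, decimal=5)"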
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implement the custom op in C++ (optional)\n",
    "To make the ONNX model with the custom op run in every language supported by ONNX Runtime, independent of Python, a C++ implementation is needed. See [inverse.hpp](../operators/math/dlib/inverse.hpp) for an example of how to do that."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-3.081008    0.20269153  0.42009977]\n",
      " [-3.3962293   2.5986686   2.4447646 ]\n",
      " [ 0.7805753  -0.20394287 -2.7528977 ]]\n"
     ]
    }
   ],
   "source": [
    "from onnxruntime_extensions import enable_py_op\n",
    "# disable the PyOp function and run with the C++ function\n",
    "enable_py_op(False)\n",
    "y = onnx_fn(x0.numpy())\n",
    "print(y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
